{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2 Strassen's algorithm for matrix multiplication"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2-1\n",
    "\n",
    "> Use Strassen’s algorithm to compute the matrix product\n",
    ">\n",
    "> $\\begin{pmatrix} 1 & 3 \\\\ 7 & 5 \\end{pmatrix}\\begin{pmatrix} 6 & 8 \\\\ 4 & 2 \\end{pmatrix}$.\n",
    ">\n",
    "> Show your work."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$S_1 = B_{12} - B_{22} = 8 - 2 = 6$\n",
    "\n",
    "$S_2 = A_{11} + A_{12} = 1 + 3 = 4$\n",
    "\n",
    "$S_3 = A_{21} + A_{22} = 7 + 5 = 12$\n",
    "\n",
    "$S_4 = B_{21} - B_{11} = 4 - 6 = -2$\n",
    "\n",
    "$S_5 = A_{11} + A_{22} = 1 + 5 = 6$\n",
    "\n",
    "$S_6 = B_{11} + B_{22} = 6 + 2 = 8$\n",
    "\n",
    "$S_7 = A_{12} - A_{22} = 3 - 5 = -2$\n",
    "\n",
    "$S_8 = B_{21} + B_{22} = 4 + 2 = 6$\n",
    "\n",
    "$S_9 = A_{11} - A_{21} = 1 - 7 = -6$\n",
    "\n",
    "$S_{10} = B_{11} + B_{12} = 6 + 8 = 14$\n",
    "\n",
    "$P_1 = A_{11} \\cdot S_1 = 1 \\times 6 = 6$\n",
    "\n",
    "$P_2 = S_{2} \\cdot B_{22} = 4 \\times 2 = 8$\n",
    "\n",
    "$P_3 = S_{3} \\cdot B_{11} = 12 \\times 6 = 72$\n",
    "\n",
    "$P_4 = A_{22} \\cdot S_4 = 5 \\times (-2) = -10$\n",
    "\n",
    "$P_5 = S_{5} \\cdot S_6 = 6 \\times 8 = 48$\n",
    "\n",
    "$P_6 = S_{7} \\cdot S_8 = (-2) \\times 6 = -12$\n",
    "\n",
    "$P_7 = S_{9} \\cdot S_{10} = (-6) \\times 14 = -84$\n",
    "\n",
    "$C_{11} = P_5 + P_4 - P_2 + P_6 = 48 - 10 - 8 - 12 = 18$\n",
    "\n",
    "$C_{12} = P_1 + P_2 = 6 + 8 = 14$\n",
    "\n",
    "$C_{21} = P_3 + P_4 = 72 - 10 = 62$\n",
    "\n",
    "$C_{22} = P_5 + P_1 - P_3 - P_7 = 48 + 6 - 72 + 84 = 66$\n",
    "\n",
    "$C = \\begin{pmatrix} 18 & 14 \\\\ 62 & 66 \\end{pmatrix}$"
   ]
  },
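  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the Strassen computation, the same product can be computed directly from the definition of matrix multiplication (four entries, eight scalar multiplications):\n",
    "\n",
    "$$\\begin{pmatrix} 1 \\cdot 6 + 3 \\cdot 4 & 1 \\cdot 8 + 3 \\cdot 2 \\\\ 7 \\cdot 6 + 5 \\cdot 4 & 7 \\cdot 8 + 5 \\cdot 2 \\end{pmatrix} = \\begin{pmatrix} 18 & 14 \\\\ 62 & 66 \\end{pmatrix}$$\n",
    "\n",
    "Both methods agree. Strassen's method uses seven scalar multiplications here instead of eight, at the cost of more additions and subtractions."
   ]
  },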
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2-2\n",
    "\n",
    "> Write pseudocode for Strassen’s algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mat_add(x, y):\n",
    "    # Element-wise sum of two equal-sized matrices.\n",
    "    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]\n",
    "\n",
    "\n",
    "def mat_sub(x, y):\n",
    "    # Element-wise difference x - y of two equal-sized matrices.\n",
    "    return [[p - q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]\n",
    "\n",
    "\n",
    "def split(m):\n",
    "    # Partition an n x n matrix (n even) into its four n/2 x n/2 quadrants.\n",
    "    h = len(m) // 2\n",
    "    top, bottom = m[:h], m[h:]\n",
    "    return ([r[:h] for r in top], [r[h:] for r in top],\n",
    "            [r[:h] for r in bottom], [r[h:] for r in bottom])\n",
    "\n",
    "\n",
    "def matrix_product_strassen(a, b):\n",
    "    # Multiply two n x n matrices; n must be an exact power of 2.\n",
    "    n = len(a)\n",
    "    if n == 1:\n",
    "        return [[a[0][0] * b[0][0]]]\n",
    "    a11, a12, a21, a22 = split(a)\n",
    "    b11, b12, b21, b22 = split(b)\n",
    "    # Seven recursive products; the sums S_1..S_10 are formed inline.\n",
    "    p1 = matrix_product_strassen(a11, mat_sub(b12, b22))\n",
    "    p2 = matrix_product_strassen(mat_add(a11, a12), b22)\n",
    "    p3 = matrix_product_strassen(mat_add(a21, a22), b11)\n",
    "    p4 = matrix_product_strassen(a22, mat_sub(b21, b11))\n",
    "    p5 = matrix_product_strassen(mat_add(a11, a22), mat_add(b11, b22))\n",
    "    p6 = matrix_product_strassen(mat_sub(a12, a22), mat_add(b21, b22))\n",
    "    p7 = matrix_product_strassen(mat_sub(a11, a21), mat_add(b11, b12))\n",
    "    # Combine the products into the four quadrants of C.\n",
    "    c11 = mat_add(mat_sub(mat_add(p5, p4), p2), p6)\n",
    "    c12 = mat_add(p1, p2)\n",
    "    c21 = mat_add(p3, p4)\n",
    "    c22 = mat_sub(mat_sub(mat_add(p5, p1), p3), p7)\n",
    "    # Reassemble the quadrants row by row into the n x n result.\n",
    "    return ([r1 + r2 for r1, r2 in zip(c11, c12)] +\n",
    "            [r1 + r2 for r1, r2 in zip(c21, c22)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2-3\n",
    "\n",
    "> How would you modify Strassen’s algorithm to multiply $n \\times n$ matrices in which $n$ is not an exact power of $2$? Show that the resulting algorithm runs in time $\\Theta(n^{\\lg 7})$."
   ]
  },
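  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pad with zeros: extend both matrices to $m \\times m$, where $m = 2^{\\lceil \\lg n \\rceil}$ is the smallest power of $2$ not less than $n$, multiply them with Strassen's algorithm, and keep the top-left $n \\times n$ block of the result.\n",
    "\n",
    "Since $n \\le m < 2n$, the running time is $\\Theta(m^{\\lg 7}) = \\Theta(n^{\\lg 7})$."
   ]
  },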
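  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see that zero padding does not change the answer is the block form of the padded product. In this sketch, $m$ is assumed to be the smallest power of $2$ with $m \\ge n$, and each $0$ stands for a zero block of the appropriate shape:\n",
    "\n",
    "$$\\begin{pmatrix} A & 0 \\\\ 0 & 0 \\end{pmatrix} \\begin{pmatrix} B & 0 \\\\ 0 & 0 \\end{pmatrix} = \\begin{pmatrix} AB & 0 \\\\ 0 & 0 \\end{pmatrix}$$\n",
    "\n",
    "The top-left $n \\times n$ block of the $m \\times m$ result is exactly $AB$. Because $n \\le m < 2n$, the cost satisfies $m^{\\lg 7} < (2n)^{\\lg 7} = 7n^{\\lg 7}$, so padding costs at most a constant factor."
   ]
  },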
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2-4\n",
    "\n",
    "> What is the largest $k$ such that if you can multiply $3 \\times 3$ matrices using $k$ multiplications (not assuming commutativity of multiplication), then you can multiply $n \\times n$ matrices in time $o(n^{\\lg 7})$? What would the running time of this algorithm be?"
   ]
  },
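  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$T(n) = kT(n/3) + \\Theta(n^2)$\n",
    "\n",
    "For $k > 9$ the master theorem gives $T(n) = \\Theta(n^{\\log_3 k})$. This is asymptotically faster than Strassen's $\\Theta(n^{\\lg 7})$ exactly when $\\log_3 k < \\lg 7$, that is, when $k < 3^{\\lg 7} \\approx 21.85$.\n",
    "\n",
    "Largest $k$: $21$\n",
    "\n",
    "Running time: $\\Theta(n^{\\log_3 21})$, where $\\log_3 21 \\approx 2.7712$"
   ]
  },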
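  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick numeric check of the threshold, assuming the standard master-theorem solution $T(n) = \\Theta(n^{\\log_3 k})$ of the recurrence for $k > 9$ (values rounded to four decimal places):\n",
    "\n",
    "$$\\log_3 21 \\approx 2.7712 \\;<\\; \\lg 7 \\approx 2.8074 \\;<\\; \\log_3 22 \\approx 2.8136$$\n",
    "\n",
    "So $21$ multiplications would give an algorithm asymptotically faster than Strassen's, while $22$ would not."
   ]
  },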
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2-5\n",
    "\n",
    "> V. Pan has discovered a way of multiplying $68 \\times 68$ matrices using $132{,}464$ multiplications, a way of multiplying $70 \\times 70$ matrices using $143{,}640$ multiplications, and a way of multiplying $72 \\times 72$ matrices using $155{,}424$ multiplications. Which method yields the best asymptotic matrix-multiplication algorithm? How does it compare to Strassen’s algorithm?"
   ]
  },
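  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A method that multiplies $m \\times m$ matrices with $k$ multiplications gives the recurrence $T(n) = kT(n/m) + \\Theta(n^2)$, whose solution is $\\Theta(n^{\\log_m k})$ when $k > m^2$.\n",
    "\n",
    "$T_1(n) = \\Theta(n^{\\log_{68}132464})$, with $\\log_{68}132464 \\approx 2.795128$\n",
    "\n",
    "$T_2(n) = \\Theta(n^{\\log_{70}143640})$, with $\\log_{70}143640 \\approx 2.795123$\n",
    "\n",
    "$T_3(n) = \\Theta(n^{\\log_{72}155424})$, with $\\log_{72}155424 \\approx 2.795147$\n",
    "\n",
    "The $70 \\times 70$ method has the smallest exponent, so it yields the best asymptotic algorithm. All three exponents are below $\\lg 7 \\approx 2.807355$, so each method is asymptotically faster than Strassen's algorithm."
   ]
  },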
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2-6\n",
    "\n",
    "> How quickly can you multiply a $kn \\times n$ matrix by an $n \\times kn$ matrix, using Strassen’s algorithm as a subroutine? Answer the same question with the order of the input matrices reversed."
   ]
  },
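  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The product is a $kn \\times kn$ matrix made of $k^2$ blocks, each the product of two $n \\times n$ matrices: $\\Theta(k^2n^{\\lg 7})$\n",
    "\n",
    "Reversed: the product is a single $n \\times n$ matrix, the sum of $k$ products of $n \\times n$ matrices: $\\Theta(kn^{\\lg 7})$"
   ]
  },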
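  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two bounds can be read off a block decomposition. As an illustrative naming (not from the exercise), write the $kn \\times n$ matrix as $k$ stacked $n \\times n$ blocks $A_1, \\dots, A_k$ and the $n \\times kn$ matrix as $k$ side-by-side $n \\times n$ blocks $B_1, \\dots, B_k$:\n",
    "\n",
    "$$\\begin{pmatrix} A_1 \\\\ \\vdots \\\\ A_k \\end{pmatrix} \\begin{pmatrix} B_1 & \\cdots & B_k \\end{pmatrix} = \\begin{pmatrix} A_1 B_1 & \\cdots & A_1 B_k \\\\ \\vdots & \\ddots & \\vdots \\\\ A_k B_1 & \\cdots & A_k B_k \\end{pmatrix}, \\qquad \\begin{pmatrix} B_1 & \\cdots & B_k \\end{pmatrix} \\begin{pmatrix} A_1 \\\\ \\vdots \\\\ A_k \\end{pmatrix} = \\sum_{i=1}^{k} B_i A_i$$\n",
    "\n",
    "The first product needs $k^2$ calls to Strassen's algorithm. The second needs $k$ calls plus $k - 1$ additions of $n \\times n$ matrices, and the $\\Theta(kn^2)$ cost of those additions is dominated by the multiplications."
   ]
  },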
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2-7\n",
    "\n",
    "> Show how to multiply the complex numbers $a + bi$ and $c + di$ using only three multiplications of real numbers. The algorithm should take $a$, $b$, $c$, and $d$ as input and produce the real component $ac - bd$ and the imaginary component $ad + bc$ separately."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$P_1 = a \\cdot (c - d) = ac - ad$\n",
    "\n",
    "$P_2 = b \\cdot (c + d) = bc + bd$\n",
    "\n",
    "$P_3 = d \\cdot (a - b) = ad - bd$\n",
    "\n",
    "Real component: $P_1 + P_3 = ac - bd$\n",
    "\n",
    "Imaginary component: $P_2 + P_3 = ad + bc$"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
